import torch

from torch import nn
from torch.nn import functional as F


def get_model(model_name, train_set=None):
    """Construct and return the model registered under ``model_name``.

    Args:
        model_name: One of ``"linear"``, ``"mlp"``, ``"mlp_dropout"``,
            ``"small_nn"``, ``"conv_nn"`` or ``"med_conv_nn"``.
        train_set: Only used for ``"linear"``: the input dimension is read
            from ``train_set[0][0].shape[0]``, i.e. the size of the first
            element of the first sample.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        ValueError: If ``model_name`` is not recognised, or if ``"linear"``
            is requested without a ``train_set``.
    """
    # Each branch returns directly so that a matched name can never fall
    # through into a later check (the previous if/if/elif chain printed
    # "Model not found!" for "linear" and "mlp").
    if model_name == "linear":
        if train_set is None:
            raise ValueError(
                "get_model('linear') requires train_set to infer the input dimension"
            )
        return LinearRegression(input_dim=train_set[0][0].shape[0], output_dim=1)

    if model_name == "mlp":
        return Mlp(n_classes=10, dropout=False)

    if model_name == "mlp_dropout":
        return Mlp(n_classes=10, dropout=True)

    if model_name == "small_nn":
        return SmallNN()

    if model_name == "conv_nn":
        return BasicConvNN()

    if model_name == "med_conv_nn":
        return MedConvNN()

    # Previously this path hit `return model` with `model` unbound.
    raise ValueError(f"Model not found! ({model_name!r})")

# =====================================================
# Logistic
class LinearRegression(torch.nn.Module):
    """Bias-free linear map from ``input_dim`` features to ``output_dim`` scores.

    No activation is applied; ``forward`` returns the raw ``x @ W.T`` result.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name ``linear`` determines the state_dict key
        # (``linear.weight``), so it is kept as-is.
        self.linear = nn.Linear(
            in_features=input_dim,
            out_features=output_dim,
            bias=False,
        )

    def forward(self, x):
        return self.linear(x)

# =====================================================
# MLP
class Mlp(nn.Module):
    """Fully connected ReLU network that returns raw class logits.

    Args:
        input_size: Number of input features; ``forward`` reshapes its input
            to ``(-1, input_size)``.
        hidden_sizes: Widths of the hidden layers, in order. May be empty, in
            which case the model is a single linear layer.
        n_classes: Number of output logits.
        bias: Whether the linear layers use a bias term.
        dropout: If true, apply dropout (p=0.5) after every hidden ReLU while
            the module is in training mode.
    """

    def __init__(self, input_size=784,
                 hidden_sizes=(512, 256),
                 n_classes=10,
                 bias=True, dropout=False):
        super().__init__()

        self.dropout = dropout
        self.input_size = input_size
        # Layer widths input -> h1 -> ... -> h_last; consecutive pairs give
        # each hidden layer's (in, out). Accepts any sequence, including [].
        sizes = [input_size, *hidden_sizes]
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(in_size, out_size, bias=bias)
             for in_size, out_size in zip(sizes[:-1], sizes[1:])]
        )
        self.output_layer = nn.Linear(sizes[-1], n_classes, bias=bias)

    def forward(self, x):
        """Return logits of shape ``(N, n_classes)``, N = x.numel() // input_size."""
        out = x.view(-1, self.input_size)
        for layer in self.hidden_layers:
            out = F.relu(layer(out))
            if self.dropout:
                # F.dropout defaults to training=True; tie it to the module's
                # train/eval mode so evaluation is deterministic.
                out = F.dropout(out, p=0.5, training=self.training)
        return self.output_layer(out)


# =====================================================
# Small convolutional NN
class BasicConvNN(nn.Module):
    """Small two-convolution CNN that returns raw class logits.

    Layout: conv(1->16, 3x3) -> ReLU -> maxpool(2) -> conv(16->32, 3x3) -> ReLU
    -> maxpool(2) -> flatten -> linear(32*5*5 -> n_classes).

    The hard-coded ``32 * 5 * 5`` flattened size matches single-channel 28x28
    inputs (28 -> 26 -> 13 -> 11 -> 5); other spatial sizes will fail at
    ``fc1``. Assumes ``x`` is (batch, 1, 28, 28) — TODO confirm with loaders.

    Args:
        input_size: Stored on the instance but not used to size any layer.
        n_classes: Number of output logits.
        batch_norm: Accepted for interface compatibility; currently ignored.
        bias: Whether the fully connected layer uses a bias term.
        dropout: If true, apply dropout (p=0.5) to the flattened features
            before the classifier while the module is in training mode.
    """

    def __init__(self, input_size=784,
                 n_classes=10,
                 batch_norm=False,
                 bias=True, dropout=False):
        super().__init__()

        self.dropout = dropout
        self.input_size = input_size

        # Convolution 1 + max pool 1
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3))
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # Convolution 2 + max pool 2
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3))
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        # Classifier
        self.fc1 = nn.Linear(32 * 5 * 5, n_classes, bias=bias)

    def forward(self, x):
        """Return logits of shape ``(batch, n_classes)``."""
        out = self.maxpool1(F.relu(self.cnn1(x)))
        out = self.maxpool2(self.relu2(self.cnn2(out)))

        out = out.view(out.size(0), -1)

        if self.dropout:
            # Regularize the features, not the logits, and only in training mode.
            out = F.dropout(out, p=0.5, training=self.training)

        # No activation on the output layer: a ReLU here would clamp every
        # negative logit to 0 (the previous behavior, left over from a removed fc2).
        return self.fc1(out)


# Medium convolutional NN
class MedConvNN(nn.Module):
    """Three-convolution CNN with a two-layer classifier; returns raw logits.

    Layout: conv(1->32) -> ReLU -> maxpool(2) -> conv(32->64) -> ReLU
    -> conv(64->64) -> ReLU -> maxpool(2) -> flatten -> linear(64*4*4 -> 100)
    -> ReLU -> linear(100 -> n_classes). All convolutions are 3x3, no padding.

    The hard-coded ``64 * 4 * 4`` flattened size matches single-channel 28x28
    inputs (28 -> 26 -> 13 -> 11 -> 9 -> 4); other spatial sizes will fail at
    ``fc1``. Assumes ``x`` is (batch, 1, 28, 28) — TODO confirm with loaders.

    Args:
        input_size: Stored on the instance but not used to size any layer.
        n_classes: Number of output logits.
        bias: Whether the fully connected layers use a bias term.
        dropout: If true, apply dropout (p=0.5) to the hidden fully connected
            activations while the module is in training mode.
    """

    def __init__(self, input_size=784,
                 n_classes=10,
                 bias=True, dropout=False):
        super().__init__()

        self.dropout = dropout
        self.input_size = input_size

        # Convolution 1 + max pool 1
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # Convolutions 2 and 3 + max pool 2
        self.cnn2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
        self.cnn3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3))
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        # Classifier
        self.fc1 = nn.Linear(64 * 4 * 4, 100, bias=bias)
        self.fc2 = nn.Linear(100, n_classes, bias=bias)

    def forward(self, x):
        """Return logits of shape ``(batch, n_classes)``."""
        out = self.maxpool1(F.relu(self.cnn1(x)))
        out = F.relu(self.cnn2(out))
        out = self.maxpool2(F.relu(self.cnn3(out)))

        out = out.view(out.size(0), -1)

        out = F.relu(self.fc1(out))
        if self.dropout:
            # Regularize the hidden layer, not the logits, and only in training mode.
            out = F.dropout(out, p=0.5, training=self.training)

        # No activation on the output layer: a ReLU here would clamp every
        # negative logit to 0.
        return self.fc2(out)



# =====================================================
# Small NN
class SmallNN(nn.Module):
    """Fully connected ReLU network (one 128-unit hidden layer by default).

    Returns raw class logits.

    Args:
        input_size: Number of input features; ``forward`` reshapes its input
            to ``(-1, input_size)``.
        hidden_sizes: Widths of the hidden layers, in order. May be empty, in
            which case the model is a single linear layer.
        n_classes: Number of output logits.
        bias: Whether the linear layers use a bias term.
        dropout: If true, apply dropout (p=0.5) after every hidden ReLU while
            the module is in training mode.
    """

    def __init__(self, input_size=784,
                 hidden_sizes=(128,),
                 n_classes=10,
                 bias=True, dropout=False):
        super().__init__()

        self.dropout = dropout
        self.input_size = input_size
        # Layer widths input -> h1 -> ... -> h_last; consecutive pairs give
        # each hidden layer's (in, out). Accepts any sequence, including [].
        sizes = [input_size, *hidden_sizes]
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(in_size, out_size, bias=bias)
             for in_size, out_size in zip(sizes[:-1], sizes[1:])]
        )
        self.output_layer = nn.Linear(sizes[-1], n_classes, bias=bias)

    def forward(self, x):
        """Return logits of shape ``(N, n_classes)``, N = x.numel() // input_size."""
        out = x.view(-1, self.input_size)
        for layer in self.hidden_layers:
            out = F.relu(layer(out))
            if self.dropout:
                # F.dropout defaults to training=True; tie it to the module's
                # train/eval mode so evaluation is deterministic.
                out = F.dropout(out, p=0.5, training=self.training)
        return self.output_layer(out)